{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Python使用tensorflow读取csv训练DNN模型\n",
    "\n",
    "本次分享到要点：\n",
    "1. 使用tf.data.experimental.make_csv_dataset函数，直接将csv数据读取到tf.data.Dataset\n",
    "2. 使用tf.feature_column.categorical_column_with_vocabulary_list，可以设置怎样读取类别特征；\n",
    "3. 使用tf.feature_column.numeric_column函数，可以设置怎样读取数字特征；\n",
    "4. 使用tf.keras.Sequential，可以搭建一个keras的dnn模型；\n",
    "5. 对于keras的model，可以用model.fit进行训练，使用model.evaluate进行准确率评估，使用model.predict进行新样本预测"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "fgZ9gjmPfSnK"
   },
   "source": [
    "### 1. 导入所需的包"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "baYFZMW_bJHh"
   },
   "outputs": [],
   "source": [
    "import functools\n",
    "\n",
    "import numpy as np\n",
    "np.set_printoptions(precision=3, suppress=True)\n",
    "\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'2.0.0'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Ncf5t6tgL5ZI"
   },
   "outputs": [],
   "source": [
    "train_file_path = \"./datas/titanic/train.csv\"\n",
    "test_file_path = \"./datas/titanic/test.csv\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Wuqj601Qw0Ml"
   },
   "source": [
    "### 2. 加载数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "iXROZm5f3V4E"
   },
   "outputs": [],
   "source": [
    "# 标签列\n",
    "LABEL_COLUMN = 'survived'\n",
    "LABELS = [0, 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Co7UJ7gpNADC"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /Users/peishuaishuai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/data/experimental/ops/readers.py:521: parallel_interleave (from tensorflow.python.data.experimental.ops.interleave_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.data.Dataset.interleave(map_func, cycle_length, block_length, num_parallel_calls=tf.data.experimental.AUTOTUNE)` instead. If sloppy execution is desired, use `tf.data.Options.experimental_determinstic`.\n"
     ]
    }
   ],
   "source": [
    "def get_dataset(file_path):\n",
    "    \"\"\"\n",
    "    构建tensorflow的数据集格式\n",
    "    \"\"\"\n",
    "    dataset = tf.data.experimental.make_csv_dataset(\n",
    "      file_path,\n",
    "      batch_size=12,\n",
    "      label_name=LABEL_COLUMN,\n",
    "      na_value=\"?\",\n",
    "      num_epochs=1,\n",
    "      ignore_errors=True)\n",
    "    return dataset\n",
    "\n",
    "# 将train和test的csv，分别加载成tensorflow的对象的格式\n",
    "raw_train_data = get_dataset(train_file_path)\n",
    "raw_test_data = get_dataset(test_file_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "qWtFYtwXIeuj"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EXAMPLES: \n",
      " OrderedDict([('sex', <tf.Tensor: id=176, shape=(12,), dtype=string, numpy=\n",
      "array([b'male', b'female', b'female', b'female', b'female', b'male',\n",
      "       b'male', b'male', b'female', b'female', b'female', b'male'],\n",
      "      dtype=object)>), ('age', <tf.Tensor: id=168, shape=(12,), dtype=float32, numpy=\n",
      "array([35., 31., 44.,  4., 31., 24., 22., 28., 28.,  3., 18., 33.],\n",
      "      dtype=float32)>), ('n_siblings_spouses', <tf.Tensor: id=174, shape=(12,), dtype=int32, numpy=array([0, 1, 0, 1, 1, 0, 0, 8, 0, 1, 2, 1], dtype=int32)>), ('parch', <tf.Tensor: id=175, shape=(12,), dtype=int32, numpy=array([0, 0, 0, 1, 1, 0, 0, 2, 0, 2, 2, 1], dtype=int32)>), ('fare', <tf.Tensor: id=173, shape=(12,), dtype=float32, numpy=\n",
      "array([  7.896,  18.   ,  27.721,  23.   ,  26.25 ,   7.496,   9.   ,\n",
      "        69.55 ,  33.   ,  41.579, 262.375,  20.525], dtype=float32)>), ('class', <tf.Tensor: id=170, shape=(12,), dtype=string, numpy=\n",
      "array([b'Third', b'Third', b'First', b'Second', b'Second', b'Third',\n",
      "       b'Third', b'Third', b'Second', b'Second', b'First', b'Third'],\n",
      "      dtype=object)>), ('deck', <tf.Tensor: id=171, shape=(12,), dtype=string, numpy=\n",
      "array([b'unknown', b'unknown', b'B', b'unknown', b'unknown', b'unknown',\n",
      "       b'unknown', b'unknown', b'unknown', b'unknown', b'B', b'unknown'],\n",
      "      dtype=object)>), ('embark_town', <tf.Tensor: id=172, shape=(12,), dtype=string, numpy=\n",
      "array([b'Cherbourg', b'Southampton', b'Cherbourg', b'Southampton',\n",
      "       b'Southampton', b'Southampton', b'Southampton', b'Southampton',\n",
      "       b'Southampton', b'Cherbourg', b'Cherbourg', b'Southampton'],\n",
      "      dtype=object)>), ('alone', <tf.Tensor: id=169, shape=(12,), dtype=string, numpy=\n",
      "array([b'y', b'n', b'y', b'n', b'n', b'y', b'y', b'n', b'y', b'n', b'n',\n",
      "       b'n'], dtype=object)>)]) \n",
      "\n",
      "LABELS: \n",
      " tf.Tensor([0 0 1 1 1 0 0 0 1 1 1 0], shape=(12,), dtype=int32)\n"
     ]
    }
   ],
   "source": [
    "# 测试一个批次\n",
    "examples, labels = next(iter(raw_train_data))\n",
    "print(\"EXAMPLES: \\n\", examples, \"\\n\")\n",
    "print(\"LABELS: \\n\", labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "9cryz31lxs3e"
   },
   "source": [
    "### 3. 数据预处理\n",
    "\n",
    "机器学习模型的输入，只能是数字"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "tSyrkSQwYHKi"
   },
   "source": [
    "#### 分类数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "mWDniduKMw-C"
   },
   "outputs": [],
   "source": [
    "# 分类数据的码表\n",
    "CATEGORIES = {\n",
    "    'sex': ['male', 'female'],\n",
    "    'class' : ['First', 'Second', 'Third'],\n",
    "    'deck' : ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J'],\n",
    "    'embark_town' : ['Cherbourg', 'Southhampton', 'Queenstown'],\n",
    "    'alone' : ['y', 'n']\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "kkxLdrsLwHPT"
   },
   "outputs": [],
   "source": [
    "categorical_columns = []\n",
    "for feature, vocab in CATEGORIES.items():\n",
    "    # 提供码表的特征输入\n",
    "    cat_col = tf.feature_column.categorical_column_with_vocabulary_list(\n",
    "        key=feature, vocabulary_list=vocab)\n",
    "    categorical_columns.append(tf.feature_column.indicator_column(cat_col))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "H18CxpHY_Nma"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[IndicatorColumn(categorical_column=VocabularyListCategoricalColumn(key='sex', vocabulary_list=('male', 'female'), dtype=tf.string, default_value=-1, num_oov_buckets=0)),\n",
       " IndicatorColumn(categorical_column=VocabularyListCategoricalColumn(key='class', vocabulary_list=('First', 'Second', 'Third'), dtype=tf.string, default_value=-1, num_oov_buckets=0)),\n",
       " IndicatorColumn(categorical_column=VocabularyListCategoricalColumn(key='deck', vocabulary_list=('A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J'), dtype=tf.string, default_value=-1, num_oov_buckets=0)),\n",
       " IndicatorColumn(categorical_column=VocabularyListCategoricalColumn(key='embark_town', vocabulary_list=('Cherbourg', 'Southhampton', 'Queenstown'), dtype=tf.string, default_value=-1, num_oov_buckets=0)),\n",
       " IndicatorColumn(categorical_column=VocabularyListCategoricalColumn(key='alone', vocabulary_list=('y', 'n'), dtype=tf.string, default_value=-1, num_oov_buckets=0))]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 分类特征列\n",
    "categorical_columns"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "9AsbaFmCeJtF"
   },
   "source": [
    "#### 连续数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "REKqO_xHPNx0"
   },
   "outputs": [],
   "source": [
    "def process_continuous_data(mean, data):\n",
    "    # 标准化数据的函数\n",
    "    data = tf.cast(data, tf.float32) * 1/(2*mean)\n",
    "    return tf.reshape(data, [-1, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "WKT1ASWpwH46"
   },
   "outputs": [],
   "source": [
    "# 提前算好的均值\n",
    "MEANS = {\n",
    "    'age' : 29.631308,\n",
    "    'n_siblings_spouses' : 0.545455,\n",
    "    'parch' : 0.379585,\n",
    "    'fare' : 34.385399\n",
    "}\n",
    "\n",
    "numerical_columns = []\n",
    "\n",
    "for feature in MEANS.keys():\n",
    "    num_col = tf.feature_column.numeric_column(\n",
    "        feature, normalizer_fn=functools.partial(process_continuous_data, MEANS[feature]))\n",
    "    numerical_columns.append(num_col)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Bw0I35xRS57V"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[NumericColumn(key='age', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=functools.partial(<function process_continuous_data at 0x7fc1f2700e60>, 29.631308)),\n",
       " NumericColumn(key='n_siblings_spouses', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=functools.partial(<function process_continuous_data at 0x7fc1f2700e60>, 0.545455)),\n",
       " NumericColumn(key='parch', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=functools.partial(<function process_continuous_data at 0x7fc1f2700e60>, 0.379585)),\n",
       " NumericColumn(key='fare', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=functools.partial(<function process_continuous_data at 0x7fc1f2700e60>, 34.385399))]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 连续特征列的列表\n",
    "numerical_columns"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "DlF_omQqtnOP"
   },
   "source": [
    "### 4. 构建模型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "kPWkC4_1l3IG"
   },
   "source": [
    "#### 创建输入层layer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "R3QAjo1qD4p9"
   },
   "source": [
    "将这两个特征列的集合相加，并且传给 `tf.keras.layers.DenseFeatures` 从而创建一个进行预处理的输入层。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "3-OYK7GnaH0r"
   },
   "outputs": [],
   "source": [
    "preprocessing_layer = tf.keras.layers.DenseFeatures(\n",
    "    categorical_columns+numerical_columns)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "lQoFh16LxtT_"
   },
   "source": [
    "从 `preprocessing_layer` 开始构建 `tf.keras.Sequential`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "3mSGsHTFPvFo"
   },
   "outputs": [],
   "source": [
    "# 构建一个DNN模型h(g(f(x)))\n",
    "model = tf.keras.Sequential([\n",
    "    preprocessing_layer,\n",
    "    tf.keras.layers.Dense(64, activation='relu'),\n",
    "    tf.keras.layers.Dense(32, activation='relu'),\n",
    "    tf.keras.layers.Dense(16, activation='relu'),\n",
    "    tf.keras.layers.Dense(1, activation='sigmoid'),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(\n",
    "    loss='binary_crossentropy',\n",
    "    optimizer='adam',\n",
    "    metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "hPdtI2ie0lEZ"
   },
   "source": [
    "### 5. 训练、评估和预测"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "8gvw1RE9zXkD"
   },
   "source": [
    "现在可以实例化和训练模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "sW-4XlLeEQ2B"
   },
   "outputs": [],
   "source": [
    "train_data = raw_train_data.shuffle(500)\n",
    "test_data = raw_test_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Q_nm28IzNDTO"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /Users/peishuaishuai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/feature_column/feature_column_v2.py:4276: IndicatorColumn._variable_shape (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n",
      "WARNING:tensorflow:From /Users/peishuaishuai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/feature_column/feature_column_v2.py:4331: VocabularyListCategoricalColumn._num_buckets (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n",
      "Epoch 1/20\n",
      "53/53 [==============================] - 2s 32ms/step - loss: 0.5944 - accuracy: 0.7177\n",
      "Epoch 2/20\n",
      "53/53 [==============================] - 0s 5ms/step - loss: 0.4505 - accuracy: 0.7990\n",
      "Epoch 3/20\n",
      "53/53 [==============================] - 0s 4ms/step - loss: 0.4437 - accuracy: 0.8086\n",
      "Epoch 4/20\n",
      "53/53 [==============================] - 0s 4ms/step - loss: 0.4136 - accuracy: 0.8198\n",
      "Epoch 5/20\n",
      "53/53 [==============================] - 0s 4ms/step - loss: 0.4130 - accuracy: 0.8150\n",
      "Epoch 6/20\n",
      "53/53 [==============================] - 0s 5ms/step - loss: 0.3994 - accuracy: 0.8246\n",
      "Epoch 7/20\n",
      "53/53 [==============================] - 0s 5ms/step - loss: 0.3983 - accuracy: 0.8293\n",
      "Epoch 8/20\n",
      "53/53 [==============================] - 0s 4ms/step - loss: 0.3899 - accuracy: 0.8262\n",
      "Epoch 9/20\n",
      "53/53 [==============================] - 0s 5ms/step - loss: 0.3863 - accuracy: 0.8278\n",
      "Epoch 10/20\n",
      "53/53 [==============================] - 0s 4ms/step - loss: 0.3809 - accuracy: 0.8278\n",
      "Epoch 11/20\n",
      "53/53 [==============================] - 0s 4ms/step - loss: 0.3778 - accuracy: 0.8293\n",
      "Epoch 12/20\n",
      "53/53 [==============================] - 0s 4ms/step - loss: 0.3822 - accuracy: 0.8309\n",
      "Epoch 13/20\n",
      "53/53 [==============================] - 0s 5ms/step - loss: 0.3693 - accuracy: 0.8421\n",
      "Epoch 14/20\n",
      "53/53 [==============================] - 0s 4ms/step - loss: 0.3664 - accuracy: 0.8421\n",
      "Epoch 15/20\n",
      "53/53 [==============================] - 0s 5ms/step - loss: 0.3498 - accuracy: 0.8389\n",
      "Epoch 16/20\n",
      "53/53 [==============================] - 0s 6ms/step - loss: 0.3566 - accuracy: 0.8501\n",
      "Epoch 17/20\n",
      "53/53 [==============================] - 0s 5ms/step - loss: 0.3630 - accuracy: 0.8501\n",
      "Epoch 18/20\n",
      "53/53 [==============================] - 0s 5ms/step - loss: 0.3687 - accuracy: 0.8549\n",
      "Epoch 19/20\n",
      "53/53 [==============================] - 0s 5ms/step - loss: 0.3614 - accuracy: 0.8389\n",
      "Epoch 20/20\n",
      "53/53 [==============================] - 0s 6ms/step - loss: 0.3483 - accuracy: 0.8501\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7fc1ea75c490>"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(train_data, epochs=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.fit?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "dense_features (DenseFeature multiple                  0         \n",
      "_________________________________________________________________\n",
      "dense (Dense)                multiple                  1600      \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              multiple                  2080      \n",
      "_________________________________________________________________\n",
      "dense_2 (Dense)              multiple                  528       \n",
      "_________________________________________________________________\n",
      "dense_3 (Dense)              multiple                  17        \n",
      "=================================================================\n",
      "Total params: 4,225\n",
      "Trainable params: 4,225\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "QyDMgBurzqQo"
   },
   "source": [
    "当模型训练完成的时候，你可以在测试集 `test_data` 上检查准确性。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "eB3R3ViVONOp"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "22/22 [==============================] - 1s 27ms/step - loss: 0.4375 - accuracy: 0.8068\n",
      "\n",
      "Test Loss 0.43752676384015515, Test Accuracy 0.8068181872367859\n"
     ]
    }
   ],
   "source": [
    "test_loss, test_accuracy = model.evaluate(test_data)\n",
    "\n",
    "print()\n",
    "print(f'Test Loss {test_loss}, Test Accuracy {test_accuracy}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "sTrn_pD90gdJ"
   },
   "source": [
    "使用 `tf.keras.Model.predict` 推断一个批次或多个批次的标签。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Qwcx74F3ojqe",
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.964],\n",
       "       [0.092],\n",
       "       [0.097],\n",
       "       [0.147],\n",
       "       [0.115],\n",
       "       [0.222],\n",
       "       [0.936],\n",
       "       [0.509],\n",
       "       [0.83 ],\n",
       "       [0.051]], dtype=float32)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictions = model.predict(test_data)\n",
    "predictions[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=5931, shape=(12,), dtype=int32, numpy=array([0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0], dtype=int32)>"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "list(test_data)[0][1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "预测活着的概率: 0.964362382888794 | 实际值: DIED\n",
      "预测活着的概率: 0.0921531617641449 | 实际值: DIED\n",
      "预测活着的概率: 0.0965917706489563 | 实际值: DIED\n",
      "预测活着的概率: 0.14703363180160522 | 实际值: DIED\n",
      "预测活着的概率: 0.11535680294036865 | 实际值: SURVIVED\n",
      "预测活着的概率: 0.2217859923839569 | 实际值: SURVIVED\n",
      "预测活着的概率: 0.9363394975662231 | 实际值: SURVIVED\n",
      "预测活着的概率: 0.5085902810096741 | 实际值: DIED\n",
      "预测活着的概率: 0.829928994178772 | 实际值: DIED\n",
      "预测活着的概率: 0.05141010880470276 | 实际值: DIED\n"
     ]
    }
   ],
   "source": [
    "# 显示部分结果\n",
    "for prediction, survived in zip(predictions[:10], list(test_data)[0][1][:10]):\n",
    "    is_survived = \"SURVIVED\" if bool(survived) else \"DIED\"\n",
    "    print(f\"预测活着的概率: {prediction[0]} | 实际值: {is_survived}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "csv.ipynb",
   "private_outputs": true,
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
